{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "135213af-05b4-4361-999c-93af84e4a429",
   "metadata": {},
   "source": [
    "# Chapter 4 - 文本分类\n",
    "- 使用 表示类模型（BERT/特征抽取类）进行分类\n",
    "- 使用 生成模型进行分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "537d8749",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"HF_HOME\"] = \"/openbayes/home/huggingface\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "205c1891-8b66-4aab-91eb-2e376fadc5fd",
   "metadata": {},
   "source": [
    "## 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0365f405-ac57-4447-a7b2-3d66d8346b41",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T13:11:51.020847Z",
     "iopub.status.busy": "2024-10-11T13:11:51.020296Z",
     "iopub.status.idle": "2024-10-11T13:12:02.743365Z",
     "shell.execute_reply": "2024-10-11T13:12:02.742883Z",
     "shell.execute_reply.started": "2024-10-11T13:11:51.020797Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 8530\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# 使用最经典的情感 分类\n",
    "data = load_dataset(\"rotten_tomatoes\")\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "32ffd40b-170e-49af-8a54-f2e0035a02d3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:40:29.727984Z",
     "iopub.status.busy": "2024-10-11T09:40:29.727808Z",
     "iopub.status.idle": "2024-10-11T09:40:29.731719Z",
     "shell.execute_reply": "2024-10-11T09:40:29.731207Z",
     "shell.execute_reply.started": "2024-10-11T09:40:29.727966Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['the rock is destined to be the 21st century\\'s new \" conan \" and that he\\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .',\n",
       "  'things really get weird , though not particularly scary : the movie is all portent and no content .'],\n",
       " 'label': [1, 0]}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[\"train\"][0, -1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09489b40-bc3b-45b5-9c2a-67516e52191b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:31:20.657026Z",
     "iopub.status.busy": "2024-10-11T09:31:20.656534Z",
     "iopub.status.idle": "2024-10-11T09:31:20.661607Z",
     "shell.execute_reply": "2024-10-11T09:31:20.660646Z",
     "shell.execute_reply.started": "2024-10-11T09:31:20.656983Z"
    },
    "tags": []
   },
   "source": [
    "## 表征类模型(Representation Models)\n",
    "### 使用和任务相关的模型(task-specific model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5ef719a5-c58f-45e9-82b5-3cc974b2ed3d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:43:13.361973Z",
     "iopub.status.busy": "2024-10-11T09:43:13.361551Z",
     "iopub.status.idle": "2024-10-11T09:43:14.906139Z",
     "shell.execute_reply": "2024-10-11T09:43:14.905550Z",
     "shell.execute_reply.started": "2024-10-11T09:43:13.361942Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment-latest\n",
    "# Labels: 0 -> Negative; 1 -> Neutral; 2 -> Positive\n",
    "# Path to our HF model\n",
    "model_path = \"cardiffnlp/twitter-roberta-base-sentiment-latest\"\n",
    "\n",
    "# Load model into pipeline\n",
    "pipe = pipeline(\n",
    "    model=model_path,\n",
    "    tokenizer=model_path,\n",
    "    return_all_scores=True,\n",
    "    device=\"cuda:0\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "986286d4-ffe6-4ed8-9116-3408aa9af312",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:43:22.243797Z",
     "iopub.status.busy": "2024-10-11T09:43:22.243353Z",
     "iopub.status.idle": "2024-10-11T09:43:29.744759Z",
     "shell.execute_reply": "2024-10-11T09:43:29.744143Z",
     "shell.execute_reply.started": "2024-10-11T09:43:22.243762Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1066/1066 [00:07<00:00, 142.24it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from transformers.pipelines.pt_utils import KeyDataset\n",
    "\n",
    "# Run inference\n",
    "y_pred = []\n",
    "for output in tqdm(pipe(KeyDataset(data[\"test\"], \"text\")), total=len(data[\"test\"])):\n",
    "    # pipe 会自动使用 模型内置的 label 3 分类\n",
    "    negative_score = output[0][\"score\"]\n",
    "    positive_score = output[2][\"score\"]\n",
    "    assignment = np.argmax([negative_score, positive_score])\n",
    "    y_pred.append(assignment)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b01d57de-9c29-4a1a-9dde-2f8f87b8a3cb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:43:30.880142Z",
     "iopub.status.busy": "2024-10-11T09:43:30.879752Z",
     "iopub.status.idle": "2024-10-11T09:43:30.883475Z",
     "shell.execute_reply": "2024-10-11T09:43:30.883027Z",
     "shell.execute_reply.started": "2024-10-11T09:43:30.880103Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "def evaluate_performance(y_true, y_pred):\n",
    "    \"\"\"Create and print the classification report\"\"\"\n",
    "    performance = classification_report(\n",
    "        y_true, y_pred,\n",
    "        target_names=[\"Negative Review\", \"Positive Review\"]\n",
    "    )\n",
    "    print(performance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d80ff3c9-a983-4300-acc3-994e1c529c88",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:43:38.560131Z",
     "iopub.status.busy": "2024-10-11T09:43:38.559663Z",
     "iopub.status.idle": "2024-10-11T09:43:38.579192Z",
     "shell.execute_reply": "2024-10-11T09:43:38.578698Z",
     "shell.execute_reply.started": "2024-10-11T09:43:38.560096Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.76      0.88      0.81       533\n",
      "Positive Review       0.86      0.72      0.78       533\n",
      "\n",
      "       accuracy                           0.80      1066\n",
      "      macro avg       0.81      0.80      0.80      1066\n",
      "   weighted avg       0.81      0.80      0.80      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c825d356-98dc-4dc5-b016-5c106e632fe0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:46:12.737877Z",
     "iopub.status.busy": "2024-10-11T09:46:12.737095Z",
     "iopub.status.idle": "2024-10-11T09:46:12.744129Z",
     "shell.execute_reply": "2024-10-11T09:46:12.743012Z",
     "shell.execute_reply.started": "2024-10-11T09:46:12.737840Z"
    },
    "tags": []
   },
   "source": [
    "#### 备注：macro/micro avg 的区别？\n",
    "- macro avg 是针对每个类别算自己的 precision/recall/f1，然后取平均\n",
    "- micro avg 先计算总体的TP、FP、FN等,然后再计算总体指标。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3ae62e8-aaaa-466c-9afd-77bbbbc81536",
   "metadata": {},
   "source": [
    "### 利用文本向量做分类\n",
    "#### 监督学习（训练分类模型）\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d97af0a1-0090-4a37-861a-f85a3f016fab",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:50:07.970348Z",
     "iopub.status.busy": "2024-10-11T09:50:07.969937Z",
     "iopub.status.idle": "2024-10-11T09:50:13.821704Z",
     "shell.execute_reply": "2024-10-11T09:50:13.821179Z",
     "shell.execute_reply.started": "2024-10-11T09:50:07.970318Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86b3afa13f834fd5a17281468abc9919",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4413030770b04880a23eb424f65f4156",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ebd237b731c4e51b85c3baa49471ff4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/10.3k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d5a76708dacb48c4bcfa5773714b9401",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/output/envs/hands-on-llm/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab317001e4d94ef0a68de87afe530d70",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff930ceafc7f401ab206563132f83430",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')\n",
    "\n",
    "# Convert text to embeddings\n",
    "train_embeddings = model.encode(data[\"train\"][\"text\"], show_progress_bar=True)\n",
    "test_embeddings = model.encode(data[\"test\"][\"text\"], show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "7112a4f6-7fec-48e6-aabd-09d98915af7d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:50:19.500386Z",
     "iopub.status.busy": "2024-10-11T09:50:19.499953Z",
     "iopub.status.idle": "2024-10-11T09:50:19.504630Z",
     "shell.execute_reply": "2024-10-11T09:50:19.504203Z",
     "shell.execute_reply.started": "2024-10-11T09:50:19.500349Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8530, 768)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ae54211-8216-40d1-9fa3-27b98a449bb0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:51:05.788721Z",
     "iopub.status.busy": "2024-10-11T09:51:05.788253Z",
     "iopub.status.idle": "2024-10-11T09:51:05.793375Z",
     "shell.execute_reply": "2024-10-11T09:51:05.792237Z",
     "shell.execute_reply.started": "2024-10-11T09:51:05.788685Z"
    },
    "tags": []
   },
   "source": [
    "#### 使用传统机器学习分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e568ade0-f445-4a6d-a2bd-4b31f32ffd66",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:53:54.567800Z",
     "iopub.status.busy": "2024-10-11T09:53:54.567285Z",
     "iopub.status.idle": "2024-10-11T09:53:54.762840Z",
     "shell.execute_reply": "2024-10-11T09:53:54.762230Z",
     "shell.execute_reply.started": "2024-10-11T09:53:54.567777Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.79      0.82      0.80       533\n",
      "Positive Review       0.81      0.78      0.79       533\n",
      "\n",
      "       accuracy                           0.80      1066\n",
      "      macro avg       0.80      0.80      0.80      1066\n",
      "   weighted avg       0.80      0.80      0.80      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Train a Logistic Regression on our train embeddings\n",
    "# 可以调参学习\n",
    "# \n",
    "clf = LogisticRegression(\n",
    "    random_state=42, \n",
    "    max_iter=1000,\n",
    "    # C=0.1,  # 正则化\n",
    ")\n",
    "clf.fit(train_embeddings, data[\"train\"][\"label\"])\n",
    "\n",
    "# Predict previously unseen instances\n",
    "y_pred = clf.predict(test_embeddings)\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deaae791-a126-42b9-b611-012d2e5c27a4",
   "metadata": {},
   "source": [
    "#### Zero-shot 方式分类（匹配）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4c6d184-ec2d-4386-af2e-9b605b3c6e08",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:55:38.052518Z",
     "iopub.status.busy": "2024-10-11T09:55:38.052215Z",
     "iopub.status.idle": "2024-10-11T09:55:38.055560Z",
     "shell.execute_reply": "2024-10-11T09:55:38.054980Z",
     "shell.execute_reply.started": "2024-10-11T09:55:38.052498Z"
    },
    "tags": []
   },
   "source": [
    "##### 注意 1 ⚠️\n",
    "我们也可以不学习，直接让模型的表示的  Embedding 和 Label 的 Embedding 进行相似度计算？\n",
    "- step1: 把所有正(负）加起来，计算平均；这样就可以得到 正样本的 embedding （其实和 MF 的思想非常的相似）\n",
    "- step2: 计算 test sample 中样本 embedding 和正负样本 embedding 的相似度，更接近的那个就是类似\n",
    "> 本质上是把分类任务变成匹配任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "11bdf357-2f8d-4e88-a99b-f6d7096dbddf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:00:49.314960Z",
     "iopub.status.busy": "2024-10-11T10:00:49.314458Z",
     "iopub.status.idle": "2024-10-11T10:00:49.359874Z",
     "shell.execute_reply": "2024-10-11T10:00:49.359334Z",
     "shell.execute_reply.started": "2024-10-11T10:00:49.314924Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8530, 769) (8530, 768)\n",
      "(2, 768)\n",
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.74      0.79      0.76       533\n",
      "Positive Review       0.77      0.73      0.75       533\n",
      "\n",
      "       accuracy                           0.76      1066\n",
      "      macro avg       0.76      0.76      0.76      1066\n",
      "   weighted avg       0.76      0.76      0.76      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Average the embeddings of all documents in each target label\n",
    "df = pd.DataFrame(np.hstack([train_embeddings, np.array(data[\"train\"][\"label\"]).reshape(-1, 1)]))\n",
    "print(df.shape, train_embeddings.shape)\n",
    "averaged_target_embeddings = df.groupby(768).mean().values\n",
    "print(averaged_target_embeddings.shape)\n",
    "\n",
    "# Find the best matching embeddings between evaluation documents and target embeddings\n",
    "sim_matrix = cosine_similarity(test_embeddings, averaged_target_embeddings)\n",
    "y_pred = np.argmax(sim_matrix, axis=1)\n",
    "\n",
    "# Evaluate the model\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "f5c3c915-a1e6-4202-94fd-afd4c403d8c5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:01:54.366806Z",
     "iopub.status.busy": "2024-10-11T10:01:54.366178Z",
     "iopub.status.idle": "2024-10-11T10:01:54.387521Z",
     "shell.execute_reply": "2024-10-11T10:01:54.387091Z",
     "shell.execute_reply.started": "2024-10-11T10:01:54.366785Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 用另外一种方式获取 label 的 embedding\n",
    "label_embeddings = model.encode([\"A negative review\",  \"A positive review\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "24db26ba-2029-42e0-aae2-55aef20cd95d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:02:09.993546Z",
     "iopub.status.busy": "2024-10-11T10:02:09.993259Z",
     "iopub.status.idle": "2024-10-11T10:02:10.016952Z",
     "shell.execute_reply": "2024-10-11T10:02:10.016462Z",
     "shell.execute_reply.started": "2024-10-11T10:02:09.993527Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.75      0.61      0.68       533\n",
      "Positive Review       0.67      0.80      0.73       533\n",
      "\n",
      "       accuracy                           0.71      1066\n",
      "      macro avg       0.71      0.71      0.70      1066\n",
      "   weighted avg       0.71      0.71      0.70      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Find the best matching label for each document\n",
    "sim_matrix = cosine_similarity(test_embeddings, label_embeddings)\n",
    "y_pred = np.argmax(sim_matrix, axis=1)\n",
    "\n",
    "\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ba84eba-dfe6-4c56-aa37-2069c37820b8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:03:00.620987Z",
     "iopub.status.busy": "2024-10-11T10:03:00.620456Z",
     "iopub.status.idle": "2024-10-11T10:03:00.627522Z",
     "shell.execute_reply": "2024-10-11T10:03:00.626405Z",
     "shell.execute_reply.started": "2024-10-11T10:03:00.620945Z"
    },
    "tags": []
   },
   "source": [
    "##### 注意 2 ⚠️\n",
    "label 的描述可以换成其他的描述，比如  \"A very negative movie review\" 和 \"A very positive movie review\"\n",
    "\n",
    "\n",
    "这是一个实验科学，合理的发散和尝试"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb3efb5b-c8f7-459d-a486-a766f0789bc5",
   "metadata": {},
   "source": [
    "## 生成模型做分类\n",
    "### Encoder-docer 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5aaa9c21-4bea-4a03-9f48-be199df6fc94",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:05:04.621719Z",
     "iopub.status.busy": "2024-10-11T10:05:04.621254Z",
     "iopub.status.idle": "2024-10-11T10:05:33.443681Z",
     "shell.execute_reply": "2024-10-11T10:05:33.443230Z",
     "shell.execute_reply.started": "2024-10-11T10:05:04.621683Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'pipeline' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Load our model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m pipe \u001b[38;5;241m=\u001b[39m \u001b[43mpipeline\u001b[49m(\n\u001b[1;32m      3\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext2text-generation\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m      4\u001b[0m     model\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgoogle/flan-t5-small\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m      5\u001b[0m     device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda:0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m      6\u001b[0m )\n\u001b[1;32m      8\u001b[0m \u001b[38;5;66;03m# Prepare our data\u001b[39;00m\n\u001b[1;32m      9\u001b[0m prompt \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIs the following sentence positive or negative? \u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'pipeline' is not defined"
     ]
    }
   ],
   "source": [
    "# Load our model\n",
    "pipe = pipeline(\n",
    "    \"text2text-generation\",\n",
    "    model=\"google/flan-t5-small\",\n",
    "    device=\"cuda:0\"\n",
    ")\n",
    "\n",
    "# Prepare our data\n",
    "prompt = \"Is the following sentence positive or negative? \"\n",
    "data = data.map(lambda example: {\"t5\": prompt + example['text']})\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "0681f687-e623-4f24-be5d-c1d2eecbab38",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:06:15.713949Z",
     "iopub.status.busy": "2024-10-11T10:06:15.713735Z",
     "iopub.status.idle": "2024-10-11T10:06:41.553361Z",
     "shell.execute_reply": "2024-10-11T10:06:41.552593Z",
     "shell.execute_reply.started": "2024-10-11T10:06:15.713931Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1066/1066 [00:25<00:00, 41.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# 推理\n",
    "y_pred = []\n",
    "for output in tqdm(pipe(KeyDataset(data[\"test\"], \"t5\")), total=len(data[\"test\"])):\n",
    "    text = output[0][\"generated_text\"]\n",
    "    y_pred.append(0 if text == \"negative\" else 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "1df7f503-4bea-42c4-a7ef-4d7f13e1c9d9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:06:41.554369Z",
     "iopub.status.busy": "2024-10-11T10:06:41.554205Z",
     "iopub.status.idle": "2024-10-11T10:06:41.564541Z",
     "shell.execute_reply": "2024-10-11T10:06:41.564107Z",
     "shell.execute_reply.started": "2024-10-11T10:06:41.554352Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.83      0.85      0.84       533\n",
      "Positive Review       0.85      0.83      0.84       533\n",
      "\n",
      "       accuracy                           0.84      1066\n",
      "      macro avg       0.84      0.84      0.84      1066\n",
      "   weighted avg       0.84      0.84      0.84      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4c183e1-e907-4180-aebf-1391942ecfc7",
   "metadata": {},
   "source": [
    "#### 备注\n",
    "这个方式本质上是一种提示学习(prompt learning)，Instruction Tuning和Prompt turning 方法的核心一样，就是去发掘语言模型本身具备的知识。而他们的不同点就在于，Prompt是去激发语言模型的补全能力，比如给出上半句生成下半句、或者做完形填空，都还是像在做language model任务，而Instruction Tuning则是激发语言模型的理解能力，通过给出更明显的指令，让模型去理解并做出正确的反馈\n",
    "\n",
    "不过现在提 prompt learning 已经比较少了，可以把 prompt leanring 认为是 instruction turning 的子集。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efb2832a-7c06-4646-b22c-a32a652c2aec",
   "metadata": {},
   "source": [
    "### Decoder 模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "cc72ba5d-8c52-4753-a122-f7d658e45942",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:11:34.709625Z",
     "iopub.status.busy": "2024-10-11T10:11:34.709331Z",
     "iopub.status.idle": "2024-10-11T10:13:56.085036Z",
     "shell.execute_reply": "2024-10-11T10:13:56.084552Z",
     "shell.execute_reply.started": "2024-10-11T10:11:34.709604Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f53cf19b9f294153b90fb42525156812",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1cee4c142fee4c8492b4c26ca50af454",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50c0e5c8e27f4b68a95f7edbdc1d1818",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/8530 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5a3423fb2494d3e86ac0c8618d37fe5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1066 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97894dc23eae4e9399cf40a39581e4aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/1066 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label', 't5', 'gpt', 'gpt2'],\n",
       "        num_rows: 8530\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label', 't5', 'gpt', 'gpt2'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label', 't5', 'gpt', 'gpt2'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load our model\n",
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=\"gpt2\",\n",
    "    device=\"cuda:0\"\n",
    ")\n",
    "\n",
    "# Prepare our data\n",
    "prompt = \"Is the following sentence positive or negative? \"\n",
    "data = data.map(lambda example: {\"gpt2\": prompt + example['text']})\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "83e07815-168e-4828-95f7-ce6102b9366f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:11:06.370649Z",
     "iopub.status.busy": "2024-10-11T10:11:06.370366Z",
     "iopub.status.idle": "2024-10-11T10:11:32.207671Z",
     "shell.execute_reply": "2024-10-11T10:11:32.207059Z",
     "shell.execute_reply.started": "2024-10-11T10:11:06.370629Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1066/1066 [00:25<00:00, 41.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.00      0.00      0.00       533\n",
      "Positive Review       0.50      1.00      0.67       533\n",
      "\n",
      "       accuracy                           0.50      1066\n",
      "      macro avg       0.25      0.50      0.33      1066\n",
      "   weighted avg       0.25      0.50      0.33      1066\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "/output/envs/hands-on-llm/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/output/envs/hands-on-llm/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/output/envs/hands-on-llm/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    }
   ],
   "source": [
    "# 推理\n",
    "y_pred = []\n",
    "for output in tqdm(pipe(KeyDataset(data[\"test\"], \"gpt2\")), total=len(data[\"test\"])):\n",
    "    text = output[0][\"generated_text\"]\n",
    "    y_pred.append(0 if text == \"negative\" else 1)\n",
    "\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d32bcd8a-6fad-479c-8709-19af8f65c336",
   "metadata": {},
   "source": [
    "### chatGPT 做文本分类 （一般也是 decoder）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "24500a0c-6274-41f9-9e21-469eed37d8f9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:19:23.399523Z",
     "iopub.status.busy": "2024-10-11T10:19:23.398944Z",
     "iopub.status.idle": "2024-10-11T10:19:24.195267Z",
     "shell.execute_reply": "2024-10-11T10:19:24.194763Z",
     "shell.execute_reply.started": "2024-10-11T10:19:23.399484Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1'"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import openai\n",
    "\n",
    "# 如果国内不方便访问 GPT, 可以使用 deepseek 代替\n",
    "# 充值一块钱用一年。。。\n",
    "YOUR_KEY_HERE=\"sk-680f79219dbf4cce9bf796769b179dfe\"\n",
    "BASE_URL=\"https://api.deepseek.com\"\n",
    "MODEL_NAME=\"deepseek-chat\"\n",
    "\n",
    "# 海外\n",
    "# YOUR_KEY_HERE=\"YOUR_KEY_HERE\"\n",
    "# BASE_URL=\"https://api.openai.com/v1\"\n",
    "# MODEL_NAME=\"gpt-3.5-turbo-0125\"\n",
    "\n",
    "\n",
    "client = openai.OpenAI(\n",
    "    api_key=YOUR_KEY_HERE,\n",
    "    base_url=BASE_URL\n",
    ")\n",
    "\n",
    "def chatgpt_generation(prompt, document, model=MODEL_NAME):\n",
    "    \"\"\"Generate an output based on a prompt and an input document.\"\"\"\n",
    "    messages=[\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a helpful assistant.\"\n",
    "            },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\":   prompt.replace(\"[DOCUMENT]\", document)\n",
    "            }\n",
    "    ]\n",
    "    chat_completion = client.chat.completions.create(\n",
    "      messages=messages,\n",
    "      model=model,\n",
    "      temperature=0\n",
    "    )\n",
    "    return chat_completion.choices[0].message.content\n",
    "\n",
    "# Define a prompt template as a base\n",
    "prompt = \"\"\"Predict whether the following document is a positive or negative movie review:\n",
    "\n",
    "[DOCUMENT]\n",
    "\n",
    "If it is positive return 1 and if it is negative return 0. Do not give any other answers.\n",
    "\"\"\"\n",
    "\n",
    "# Predict the target using GPT\n",
    "document = \"unpretentious , charming , quirky , original\"\n",
    "chatgpt_generation(prompt, document)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c13f29a8-a9d0-45d2-be07-f8c118d26fdb",
   "metadata": {},
   "source": [
    "下一步将是针对整个评估数据集运行OpenAI的一个模型。但是，只有当您有足够的令牌时才运行此操作，因为这将为整个测试数据集(1066条记录)调用API。\n",
    "> 如果没钱就不要🙅‍♂️运行下面的内容"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "b674eb81-2e84-4d1f-a4c2-d261db65a247",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:21:08.818179Z",
     "iopub.status.busy": "2024-10-11T10:21:08.817136Z",
     "iopub.status.idle": "2024-10-11T10:33:01.378764Z",
     "shell.execute_reply": "2024-10-11T10:33:01.378184Z",
     "shell.execute_reply.started": "2024-10-11T10:21:08.818137Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1066/1066 [11:52<00:00,  1.50it/s]\n"
     ]
    }
   ],
   "source": [
    "predictions = [chatgpt_generation(prompt, doc) for doc in tqdm(data[\"test\"][\"text\"])]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "8525e0f6-4a29-4fec-ac20-20b9f4751c63",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T10:33:01.380143Z",
     "iopub.status.busy": "2024-10-11T10:33:01.379908Z",
     "iopub.status.idle": "2024-10-11T10:33:01.395185Z",
     "shell.execute_reply": "2024-10-11T10:33:01.394792Z",
     "shell.execute_reply.started": "2024-10-11T10:33:01.380118Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.89      0.95      0.92       533\n",
      "Positive Review       0.95      0.88      0.91       533\n",
      "\n",
      "       accuracy                           0.91      1066\n",
      "      macro avg       0.92      0.91      0.91      1066\n",
      "   weighted avg       0.92      0.91      0.91      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Extract predictions\n",
    "y_pred = [int(pred) for pred in predictions]\n",
    "\n",
    "# Evaluate performance\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54144203-1a15-49d4-82de-da0fa5d7e4ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(\"hello world\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a3672d2-8488-4277-94cd-ff74ff4232f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hands-on-llm",
   "language": "python",
   "name": "hands-on-llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
